import soapy
import aotools

import os
import sys
import numpy as np
import matplotlib.pyplot as plt


# All input and output paths below are built from the current working
# directory, so the script is expected to be launched from the directory that
# contains configuration_dir/ and result/.
work_path = os.getcwd()
config_file = work_path + '/configuration_dir/generate_data.yaml'

# Build the soapy simulation from the YAML configuration and run its set-up
# steps; sim.sciCams[0].bestPSF is read further down as the reference PSF.
sim = soapy.Sim(config_file)
sim.aoinit()
sim.makeIMat()

# The result index for processing is stored in these two dictionaries.
# Keys are the swept values exactly as they appear in directory names
# ('r0=<key>' and 'star_mag=<key>'); values are the numbered result
# sub-directory selected for that case.
result_index_for_diff_r0 = {'0.06': 0, '0.08': 3, '0.1': 2, '0.12': 0, '0.14': 2,
                            '0.16': 7, '0.18': 0, '0.2': 0}
result_index_for_diff_mag = {'9': 3, '10': 0, '11': 2, '12': 2, '14': 1, '15': 6,
                             '16': 0, '17': 6}

# Make necessary variables for loading data for testing: one input directory
# per model variant (CNN / Vision Transformer, 'normalized' / 'unnormalized').
# The trailing number picks one particular run directory per variant; its
# meaning is not visible here.
# NOTE: these variable NAMES are load-bearing. get_variable_name() recovers
# them from globals(), and the processing loop builds labels such as
# 'CNN_nor' / 'Vi_unnor' from their '_'-separated parts. Renaming them changes
# output names and plot titles.
inter_CNN_nor_path = work_path + '/result/detailed_soapy_test_result/CNN/normalized/5'
inter_CNN_unnor_path = work_path + '/result/detailed_soapy_test_result/CNN/unnormalized/0'
inter_Vi_nor_path = work_path + '/result/detailed_soapy_test_result/Vision Transformer/normalized/12'
inter_Vi_unnor_path = work_path + '/result/detailed_soapy_test_result/Vision Transformer/unnormalized/12'

inter_model_path = [inter_CNN_nor_path, inter_CNN_unnor_path, inter_Vi_nor_path, inter_Vi_unnor_path]


# make necessary variables for saving data after processing: output
# directories mirroring the input directories above. Keep this list in the
# same variant order as inter_model_path (CNN-nor, CNN-unnor, ViT-nor,
# ViT-unnor); results for inter_model_path[i] are written under
# inter_model_path_result[i].
inter_CNN_nor_path_result = work_path + '/result/azimuthal_average_of_image/CNN/normalized/5'
inter_CNN_unnor_path_result = work_path + '/result/azimuthal_average_of_image/CNN/unnormalized/0'
inter_Vi_nor_path_result = work_path + '/result/azimuthal_average_of_image/Vision Transformer/normalized/12'
inter_Vi_unnor_path_result = work_path + '/result/azimuthal_average_of_image/Vision Transformer/unnormalized/12'

inter_model_path_result = [inter_CNN_nor_path_result,
                           inter_CNN_unnor_path_result,
                           inter_Vi_nor_path_result,
                           inter_Vi_unnor_path_result]


# Define function for converting variable name to string format
def get_variable_name(variable):
    global_direc = globals()
    for name in global_direc:
        if global_direc[name] == variable:
            return name


class NumOfMaxError(Exception):
    """Raised when an image's maximum value occurs more than once.

    The script checks that the maximum is unique before locating it with
    ``np.argmax``; with several equal maxima the peak position (and hence the
    sub-image centre) would be ambiguous.
    """

    def __init__(self):
        super().__init__('The maximum value occurs more than once in the image!')


def count_number(matrix, num):
    """Return how many elements of ``matrix`` are equal to ``num``.

    Every element is considered, whatever the array's shape (square or not,
    any number of dimensions). Comparison uses ``==``, so a NaN ``num`` never
    matches anything.

    Args:
        matrix: Array-like of values to search.
        num: Value to count.

    Returns:
        The number of matching elements, as a Python ``int``.
    """
    # Vectorised equality + count instead of a nested Python loop. The old
    # loop also used shape[0] for both axes, which skipped columns of wide
    # matrices and indexed out of range on tall ones.
    return int(np.count_nonzero(np.asarray(matrix) == num))


def symmetrical_expand(azimuthal_average):
    """Mirror a 1-D radial profile about its first sample.

    For an input ``[a0, a1, ..., a(n-1)]`` the result has length ``2 * n``
    and reads ``[a(n-1), ..., a1, a0, a0, a1, ..., a(n-1)]``: the reversed
    profile followed by the profile itself. The output buffer is allocated
    with ``np.empty`` (default dtype float64), so values are stored as
    float64 regardless of the input dtype.
    """
    n = azimuthal_average.shape[0]
    mirrored = np.empty((2 * n, ))
    # Right half: the profile as-is, starting at the centre.
    mirrored[n:] = azimuthal_average
    # Left half: the profile reversed, so it ends at the centre.
    mirrored[:n] = azimuthal_average[::-1]
    return mirrored


# Half the side length (in pixels) of the square sub-image cut out around the
# peak of each image; every sub-image is (2 * half_area_size) pixels square.
half_area_size = 60


# Return the square window of `image` centred on its unique maximum.
def _centered_sub_image(image):
    """Cut a (2 * half_area_size)-square window centred on the image peak.

    Args:
        image: 2-D array.

    Returns:
        The slice ``image[r - h:r + h, c - h:c + h]``, where ``(r, c)`` is the
        position of the maximum and ``h`` is ``half_area_size``.

    Raises:
        NumOfMaxError: if the maximum value occurs more than once, so the
            centre would be ambiguous.
        ValueError: if the window does not fit inside the image. Without this
            check, a peak closer than half_area_size to the border produces a
            negative slice start. Python reads that as counting from the end,
            so the sub-image comes out empty or off-centre without any error.
    """
    if count_number(image, np.max(image)) != 1:
        raise NumOfMaxError
    row, col = np.unravel_index(np.argmax(image), image.shape)
    n_rows, n_cols = image.shape
    if not (half_area_size <= row <= n_rows - half_area_size
            and half_area_size <= col <= n_cols - half_area_size):
        raise ValueError(
            f'A {2 * half_area_size}x{2 * half_area_size} window centred on the '
            f'peak at ({row}, {col}) does not fit inside the '
            f'{n_rows}x{n_cols} image')
    return image[row - half_area_size: row + half_area_size,
                 col - half_area_size: col + half_area_size]


# Extract, azimuthally average, mirror and save one image; return the average.
def _process_and_save(image, output_dir):
    """Process one image and write the three result arrays to ``output_dir``.

    Saves ``sub_image.npy`` (the peak-centred window),
    ``image_azimuthal_average.npy`` (aotools azimuthal average of that window)
    and ``symmetrical_azimuthal_average.npy`` (that profile mirrored by
    symmetrical_expand). ``output_dir`` is created if it does not exist,
    because np.save does not create parent directories.

    Returns:
        The azimuthal average of the sub-image.
    """
    sub_image = _centered_sub_image(image)
    azimuthal = aotools.image_processing.azimuthal_average(sub_image)
    symmetrical = symmetrical_expand(azimuthal)
    os.makedirs(output_dir, exist_ok=True)
    for file_name, data in (('sub_image', sub_image),
                            ('image_azimuthal_average', azimuthal),
                            ('symmetrical_azimuthal_average', symmetrical)):
        np.save(output_dir + '/' + file_name, data,
                allow_pickle=False, fix_imports=False)
    return azimuthal


# Reference case: the best PSF recorded by the simulator's first science camera.
bestPSF_path = work_path + '/result/azimuthal_average_of_image/bestPSF'
_process_and_save(sim.sciCams[0].bestPSF, bestPSF_path)

# Parameter sweeps processed for every model, in this order:
# (sweep directory, directory-name prefix, image-name tag, result-index dict).
_sweeps = (
    ('different_r0', 'r0=', '_r0=', result_index_for_diff_r0),
    ('different_mag', 'star_mag=', '_mag=', result_index_for_diff_mag),
)

# Process every model variant. inter_model_path and inter_model_path_result
# are parallel lists: results for the i-th input directory go under the i-th
# output directory.
for model_path, model_path_result_root in zip(inter_model_path, inter_model_path_result):
    # Label derived from the variable name, e.g. 'inter_CNN_nor_path' -> 'CNN_nor'.
    # Computed once per model so both sweeps use it directly.
    model_name = get_variable_name(model_path)
    inter_image_name = '_'.join(model_name.split('_')[1:3])

    for sweep_dir, dir_prefix, name_tag, result_index in _sweeps:
        for sweep_value, run_index in result_index.items():
            # e.g. '/different_r0/r0=0.06/0' or '/different_mag/star_mag=9/3'
            case_dir = '/' + sweep_dir + '/' + dir_prefix + sweep_value + '/' + str(run_index)
            image_name = inter_image_name + name_tag + sweep_value

            image_value = np.load(model_path + case_dir + '/image_after.npy')
            # Also expose the loaded image as a module global named image_name
            # (e.g. 'CNN_nor_r0=0.06'); such names are only reachable via globals().
            setattr(sys.modules[__name__], image_name, image_value)

            image_azimuthal_average = _process_and_save(
                image_value, model_path_result_root + case_dir)

            # Plot the profile, titled with the same label as image_name.
            plt.plot(image_azimuthal_average)
            plt.title(image_name)
            plt.show()
